% This file reproduces Figure 6 in the paper

% Start from a clean session: close every open figure and clear the
% workspace before (re)declaring the shared state.
close all
clear

% Shared model state. These are declared global, presumably so the helper
% functions defined in other files (PDF, CDF, C_x, C_s, x_dot_max,
% s_dot_max) can read them -- TODO confirm which helpers use which globals.
global grid_size rho delta alpha
global gamma_x gamma_s
global c_x_1 c_x_2 c_s_1 c_s_2
global x s x_dot_bar s_dot_bar

% ---- Model and solver parameters ------------------------------------------
grid_size = 1050;           % points per axis of the (x,s) grid on [0,1]^2
rho       = 30;             % enters the value update as a 1/rho factor
delta     = .1;
alpha     = 2;              % convexity parameter
x_dot_bar = 15-delta;       % upper bound on x_dot
s_dot_bar = 10-delta;       % upper bound on s_dot
error_tol = 10e-12;         % NOTE(review): 10e-12 == 1e-11; confirm 1e-12 was not intended

% ---- Cost functions: c(z) = c_1*z + c_2*z^alpha ---------------------------
% The linear coefficients are zero, so only the z^alpha term carries cost.
c_x_1 = 0;
c_s_1 = 0;

% Assumption 3: gamma_x and gamma_s (no semicolon, so both are printed).
gamma_x = min(.4*(PDF(0)-PDF(1)),1)
gamma_s = min(.35*(PDF(0)-PDF(1)),1)

% Assumption 3: c_2 is set so that alpha*c_2*delta^(alpha-1) -- the marginal
% cost of the z^alpha term at z = delta -- equals a weighted mix of PDF(1)
% and min(PDF(0)-gamma, PDF(gamma)). With weight 0 only the min term counts.
convex_weight_x = 0;
convex_weight_s = 0;
convex_coef = @(w,g) (w*PDF(1) + (1-w)*min(PDF(0)-g,PDF(g)))/(alpha*delta^(alpha-1));
c_x_2 = convex_coef(convex_weight_x,gamma_x)
c_s_2 = convex_coef(convex_weight_s,gamma_s)

% The state grid on [0,1]^2: x varies down the rows and is constant along
% each row; s varies across the columns and is constant down each column.
x = linspace(0,1,grid_size)'*ones(1,grid_size);
s = ones(grid_size,1)*linspace(0,1,grid_size);

% Initial guess: each control starts at its own upper bound (x_dot at
% x_dot_bar, s_dot at s_dot_bar) and both value functions start at zero.
% The initial controls only enter the first convergence measure, so this
% choice does not change the converged solution.
x_dot = x_dot_bar*ones(grid_size,grid_size);
s_dot = s_dot_bar*ones(grid_size,grid_size);
V_x = zeros(grid_size,grid_size);
V_s = zeros(grid_size,grid_size);

% Value-function iteration. Each pass differentiates V_x and V_s on the grid,
% updates the controls from those derivatives, and re-evaluates both value
% functions. It stops once the summed change of the two value functions and
% the two controls is at most error_tol.
% The convergence measure is named err (not error) so the script does not
% shadow MATLAB's built-in error() function.
%
% NOTE(review): gradient() is given the spacing 1/(grid_size-1)/2, half the
% actual spacing of linspace(0,1,grid_size), so every finite-difference
% derivative comes out doubled. Kept unchanged because it shapes the
% reproduced figure -- confirm the /2 is intentional.
% NOTE(review): norm(A,Inf) on a matrix is the maximum absolute ROW SUM, not
% the largest elementwise change; max(abs(A(:))) would give the latter.
% Kept as-is to preserve the original stopping rule.
grad_spacing = 1/(grid_size-1)/2;
err = error_tol*10;

while err > error_tol
    % gradient returns the derivative across columns (the s direction)
    % first, then the derivative down rows (the x direction).
    [dV_x_over_ds,dV_x_over_dx] = gradient(V_x,grad_spacing,grad_spacing);
    [dV_s_over_ds,dV_s_over_dx] = gradient(V_s,grad_spacing,grad_spacing);

    % x_dot is chosen from dV_x/dx and s_dot from dV_s/ds.
    x_dot_new = x_dot_max(dV_x_over_dx);
    s_dot_new = s_dot_max(dV_s_over_ds);

    % Value update: CDF term plus (1/rho) times (drift terms minus cost).
    V_x_new = CDF(x-s)+(1/rho)*(-C_x(x_dot_new)+dV_x_over_dx.*x_dot_new+dV_x_over_ds.*s_dot_new);
    V_s_new = CDF(s-x)+(1/rho)*(-C_s(s_dot_new)+dV_s_over_dx.*x_dot_new+dV_s_over_ds.*s_dot_new);

    err = norm(V_x_new-V_x,Inf)+norm(V_s_new-V_s,Inf)+norm(x_dot_new-x_dot,Inf)+norm(s_dot_new-s_dot,Inf);

    V_x = V_x_new;
    V_s = V_s_new;
    x_dot = x_dot_new;
    s_dot = s_dot_new;
end

% Ad hoc smoothing for display: average the fine-grid controls over
% block_size-by-block_size tiles, giving a coarse grid with one value per
% tile (1050/50 = 21 points per axis with the parameters above).
block_size = 50;
myMeanFunction = @(block_struct) mean(block_struct.data(:));
x_dot_smooth = blockproc(x_dot, [block_size block_size], myMeanFunction);
s_dot_smooth = blockproc(s_dot, [block_size block_size], myMeanFunction);
x_dot = x_dot_smooth;
s_dot = s_dot_smooth;

% Take the coarse size from the smoothed result instead of hard-coding it,
% so changing the fine grid_size cannot desynchronise the coarse grid.
grid_size = size(x_dot,1);
x = linspace(0,1,grid_size)'*ones(1,grid_size);
s = ones(grid_size,1)*linspace(0,1,grid_size);

% Edge clipping so no velocity on the boundary points out of [0,1]^2.
% Row 1 / row grid_size are x = 0 / x = 1; column 1 / column grid_size are
% s = 0 / s = 1. Each far edge is clipped using its OWN values (the
% previous code read row 1 / column 1 there, which forced both far edges
% to zero regardless of the solution).
x_dot(1,:)         = max(x_dot(1,:),zeros(1,grid_size));
s_dot(:,1)         = max(s_dot(:,1),zeros(grid_size,1));
x_dot(grid_size,:) = min(x_dot(grid_size,:),zeros(1,grid_size));
s_dot(:,grid_size) = min(s_dot(:,grid_size),zeros(grid_size,1));

% Draw the smoothed control field as streamlines and export it to an EPS file.
figure
axis equal
hold on

% streamslice(X,Y,U,V) puts X on the horizontal axis: s (the column
% coordinate) is horizontal with velocity s_dot, and x (the row coordinate)
% is vertical with velocity x_dot.
streamslice(s(1,:),x(:,1),s_dot,x_dot)

% Explicit limits with a small margin. This replaces the earlier
% `axis tight`, which these limits overrode anyway.
axis([-0.02 1.02 -0.02 1.02])

% Labels match the data drawn above: s across, x up.
% NOTE(review): if the paper's Figure 6 puts x on the horizontal axis,
% keep xlabel('x')/ylabel('s') and transpose the streamslice inputs instead.
xlabel('s')
ylabel('x','Rotation',0)
print('figure6','-depsc','-painters')